{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Za8-Nr5k11fh"
      },
      "source": [
        "##### Copyright 2018 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "Eq10uEbw0E4l"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "06ndLauQxiQm"
      },
      "source": [
        "# Train Your Own Model and Convert It to TFLite"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Dtav_aq2xh6n"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/examples/blob/master/courses/udacity_intro_to_tensorflow_lite/tflite_c04_exercise_convert_model_to_tflite_solution.ipynb\">\n",
        "    <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />\n",
        "    Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/examples/blob/master/courses/udacity_intro_to_tensorflow_lite/tflite_c04_exercise_convert_model_to_tflite_solution.ipynb\">\n",
        "    <img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />\n",
        "    View source on GitHub</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ka96-ajYzxVU"
      },
      "source": [
        "This notebook uses the [Fashion MNIST](https://github.com/zalandoresearch/fashion-mnist) dataset, which contains 70,000 grayscale images in 10 categories. The images show individual articles of clothing at low resolution (28 by 28 pixels), as seen here:\n",
        "\n",
        "<table>\n",
        "  <tr><td>\n",
        "    <img src=\"https://tensorflow.org/images/fashion-mnist-sprite.png\"\n",
        "         alt=\"Fashion MNIST sprite\"  width=\"600\">\n",
        "  </td></tr>\n",
        "  <tr><td align=\"center\">\n",
        "    <b>Figure 1.</b> <a href=\"https://github.com/zalandoresearch/fashion-mnist\">Fashion-MNIST samples</a> (by Zalando, MIT License).<br/>\u0026nbsp;\n",
        "  </td></tr>\n",
        "</table>\n",
        "\n",
        "Fashion MNIST is intended as a drop-in replacement for the classic [MNIST](http://yann.lecun.com/exdb/mnist/) dataset, which is often used as the \"Hello, World\" of machine learning programs for computer vision. The MNIST dataset contains images of handwritten digits (0, 1, 2, etc.) in a format identical to that of the articles of clothing we'll use here.\n",
        "\n",
        "This notebook uses Fashion MNIST for variety, and because it's a slightly more challenging problem than regular MNIST. Both datasets are relatively small and are used to verify that an algorithm works as expected, which makes them good starting points for testing and debugging code.\n",
        "\n",
        "Fashion MNIST ships with 60,000 training images and 10,000 test images. This notebook loads the 60,000-image training set through TensorFlow Datasets and splits it 80/10/10: 48,000 images to train the network, 6,000 for validation, and 6,000 to evaluate how accurately the network learned to classify images."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rjOAfhgd__Sp"
      },
      "source": [
        "# Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pfyZKowNAQ4j"
      },
      "outputs": [],
      "source": [
        "# TensorFlow and tf.keras\n",
        "import tensorflow as tf\n",
        "from tensorflow import keras\n",
        "\n",
        "# Helper libraries\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import pathlib\n",
        "\n",
        "\n",
        "print(tf.__version__)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tadPBTEiAprt"
      },
      "source": [
        "# Download Fashion MNIST Dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jmSkLCyRKqKB"
      },
      "outputs": [],
      "source": [
        "import tensorflow_datasets as tfds\n",
        "tfds.disable_progress_bar()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XcNwi6nFKneZ"
      },
      "outputs": [],
      "source": [
        "splits, info = tfds.load('fashion_mnist', with_info=True, as_supervised=True, \n",
        "                         split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'])\n",
        "\n",
        "(train_examples, validation_examples, test_examples) = splits\n",
        "\n",
        "num_examples = info.splits['train'].num_examples\n",
        "num_classes = info.features['label'].num_classes"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-eAv71FRm4JE"
      },
      "outputs": [],
      "source": [
        "class_names = ['T-shirt_top', 'Trouser', 'Pullover', 'Dress', 'Coat',\n",
        "               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hXe6jNokqX3_"
      },
      "outputs": [],
      "source": [
        "with open('labels.txt', 'w') as f:\n",
        "  f.write('\\n'.join(class_names))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iubWCThbdN8K"
      },
      "outputs": [],
      "source": [
        "IMG_SIZE = 28"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZAkuq0V0Aw2X"
      },
      "source": [
        "# Preprocessing data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_5SIivkunKCC"
      },
      "source": [
        "## Preprocess"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BwyhsyGydHDl"
      },
      "outputs": [],
      "source": [
        "def format_example(image, label):\n",
        "  image = tf.cast(image, tf.float32)\n",
        "  image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))\n",
        "  image = image / 255.0\n",
        "  return image, label"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HAlBlXOUMwqe"
      },
      "outputs": [],
      "source": [
        "BATCH_SIZE = 32"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JM4HfIJtnNEk"
      },
      "source": [
        "## Create a Dataset from images and labels"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uxe2I3oxLDhq"
      },
      "outputs": [],
      "source": [
        "train_batches = train_examples.cache().shuffle(num_examples//4).batch(BATCH_SIZE).map(format_example).prefetch(1)\n",
        "validation_batches = validation_examples.cache().batch(BATCH_SIZE).map(format_example).prefetch(1)\n",
        "test_batches = test_examples.cache().batch(1).map(format_example)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "M-topQaOm_LM"
      },
      "source": [
        "# Building the model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kDqcwksFB1bh"
      },
      "outputs": [],
      "source": [
        "model = tf.keras.Sequential([\n",
        "  tf.keras.layers.Conv2D(16, 3, activation='relu', input_shape=(IMG_SIZE, IMG_SIZE, 1)),\n",
        "  tf.keras.layers.MaxPooling2D(),\n",
        "  tf.keras.layers.Conv2D(32, 3, activation='relu'),\n",
        "  tf.keras.layers.Flatten(),\n",
        "  tf.keras.layers.Dense(64, activation='relu'),\n",
        "  tf.keras.layers.Dense(10)\n",
        "])\n",
        "\n",
        "model.compile(\n",
        "    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "    optimizer='adam',\n",
        "    metrics=['accuracy'])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zEMOz-LDnxgD"
      },
      "source": [
        "## Train"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1fk2faPsjqfU"
      },
      "outputs": [],
      "source": [
        "validation_batches"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DGJe_CNvjnhT"
      },
      "outputs": [],
      "source": [
        "train_batches"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JGlNoRtzCP4_"
      },
      "outputs": [],
      "source": [
        "model.fit(train_batches, \n",
        "          epochs=10,\n",
        "          validation_data=validation_batches)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TZT9-7w9n4YO"
      },
      "source": [
        "# Exporting to TFLite"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9dq78KBkCV2_"
      },
      "outputs": [],
      "source": [
        "export_dir = 'saved_model/1'\n",
        "tf.saved_model.save(model, export_dir)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EDGiYrBdE6fl"
      },
      "outputs": [],
      "source": [
        "optimization = tf.lite.Optimize.DEFAULT"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RbcS9C00CzGe"
      },
      "outputs": [],
      "source": [
        "# Convert the model.\n",
        "converter = tf.lite.TFLiteConverter.from_saved_model(export_dir)\n",
        "converter.optimizations = [optimization]\n",
        "tflite_model = converter.convert()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "q5PWCDsTC3El"
      },
      "outputs": [],
      "source": [
        "tflite_model_file = 'model.tflite'\n",
        "\n",
        "with open(tflite_model_file, \"wb\") as f:\n",
        "  f.write(tflite_model)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SR6wFcQ1Fglm"
      },
      "source": [
        "# Test the model with the TFLite interpreter"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rKcToCBEC-Bu"
      },
      "outputs": [],
      "source": [
        "# Load TFLite model and allocate tensors.\n",
        "interpreter = tf.lite.Interpreter(model_content=tflite_model)\n",
        "interpreter.allocate_tensors()\n",
        "\n",
        "input_index = interpreter.get_input_details()[0][\"index\"]\n",
        "output_index = interpreter.get_output_details()[0][\"index\"]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "E8EpFpIBFkq8"
      },
      "outputs": [],
      "source": [
        "# Gather results for the first 50 test images (test_batches is not shuffled)\n",
        "predictions = []\n",
        "test_labels = []\n",
        "test_images = []\n",
        "\n",
        "for img, label in test_batches.take(50):\n",
        "  interpreter.set_tensor(input_index, img)\n",
        "  interpreter.invoke()\n",
        "  predictions.append(interpreter.get_tensor(output_index))\n",
        "  test_labels.append(label[0])\n",
        "  test_images.append(np.array(img))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "kSjTmi05Tyod"
      },
      "outputs": [],
      "source": [
        "#@title Utility functions for plotting\n",
        "# Utilities for plotting\n",
        "\n",
        "def plot_image(i, predictions_array, true_label, img):\n",
        "  predictions_array, true_label, img = predictions_array[i], true_label[i], img[i]\n",
        "  plt.grid(False)\n",
        "  plt.xticks([])\n",
        "  plt.yticks([])\n",
        "  \n",
        "  img = np.squeeze(img)\n",
        "\n",
        "  plt.imshow(img, cmap=plt.cm.binary)\n",
        "\n",
        "  # The model outputs logits; apply softmax to get class probabilities.\n",
        "  probabilities = tf.nn.softmax(predictions_array[0]).numpy()\n",
        "  predicted_label = np.argmax(probabilities)\n",
        "  if predicted_label == true_label.numpy():\n",
        "    color = 'green'\n",
        "  else:\n",
        "    color = 'red'\n",
        "    \n",
        "  plt.xlabel(\"{} {:2.0f}% ({})\".format(class_names[predicted_label],\n",
        "                                100*np.max(probabilities),\n",
        "                                class_names[true_label]),\n",
        "                                color=color)\n",
        "\n",
        "def plot_value_array(i, predictions_array, true_label):\n",
        "  predictions_array, true_label = predictions_array[i], true_label[i]\n",
        "  plt.grid(False)\n",
        "  plt.xticks(list(range(10)), class_names, rotation='vertical')\n",
        "  plt.yticks([])\n",
        "  # The model outputs logits; apply softmax so the bars fit the [0, 1] axis.\n",
        "  probabilities = tf.nn.softmax(predictions_array[0]).numpy()\n",
        "  thisplot = plt.bar(range(10), probabilities, color=\"#777777\")\n",
        "  plt.ylim([0, 1])\n",
        "  predicted_label = np.argmax(probabilities)\n",
        "\n",
        "  thisplot[predicted_label].set_color('red')\n",
        "  thisplot[true_label].set_color('green')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "ZZwg0wFaVXhZ"
      },
      "outputs": [],
      "source": [
        "#@title Visualize the outputs { run: \"auto\" }\n",
        "index = 12 #@param {type:\"slider\", min:0, max:49, step:1}\n",
        "plt.figure(figsize=(6,3))\n",
        "plt.subplot(1,2,1)\n",
        "plot_image(index, predictions, test_labels, test_images)\n",
        "plt.show()\n",
        "plot_value_array(index, predictions, test_labels)\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "076bo3FMpRDb"
      },
      "source": [
        "# Download TFLite model and assets\n",
        "\n",
        "**NOTE: You might have to run the cell below twice**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XsPXqPlgZPjE"
      },
      "outputs": [],
      "source": [
        "try:\n",
        "  from google.colab import files\n",
        "\n",
        "  files.download(tflite_model_file)\n",
        "  files.download('labels.txt')\n",
        "except ImportError:\n",
        "  # Not running in Colab; the files stay in the working directory.\n",
        "  pass"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "H8t7_jRiz9Vw"
      },
      "source": [
        "# Prepare the test images for download (Optional)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Fi09nIps0gBu"
      },
      "outputs": [],
      "source": [
        "!mkdir -p test_images"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sF7EZ63J0hZs"
      },
      "outputs": [],
      "source": [
        "from PIL import Image\n",
        "\n",
        "for index, (image, label) in enumerate(test_batches.take(50)):\n",
        "  image = tf.cast(image * 255.0, tf.uint8)\n",
        "  image = tf.squeeze(image).numpy()\n",
        "  pil_image = Image.fromarray(image)\n",
        "  pil_image.save('test_images/{}_{}.jpg'.format(class_names[label[0]].lower(), index))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uM35O-uv0iWS"
      },
      "outputs": [],
      "source": [
        "!ls test_images"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aR20r4qW0jVm"
      },
      "outputs": [],
      "source": [
        "!zip -qq fmnist_test_images.zip -r test_images/"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tjk4537X0kWN"
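      },
      "outputs": [],
      "source": [
        "try:\n",
        "  # Import here so this cell does not depend on the earlier download cell.\n",
        "  from google.colab import files\n",
        "\n",
        "  files.download('fmnist_test_images.zip')\n",
        "except ImportError:\n",
        "  # Not running in Colab; the archive stays in the working directory.\n",
        "  pass"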
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "tflite_c04_exercise_convert_model_to_tflite_solution.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
